#!/usr/bin/env python3
############################################################
# Program is part of MintPy                                #
# Copyright (c) 2013, Zhang Yunjun, Heresh Fattahi         #
# Author: Zhang Yunjun, 2017                               #
############################################################


import os
import sys
import time
import argparse
import numpy as np

from mintpy.objects.resample import resample
from mintpy.defaults.template import get_template_content
from mintpy.utils import (
    arg_group,
    readfile,
    writefile,
    utils as ut,
    attribute as attr,
)


######################################################################################
# Text returned by get_template_content('geocode'); together with EXAMPLE it forms
# the epilog of the --help output (see create_parser).
TEMPLATE = get_template_content('geocode')

# Usage examples shown at the end of the --help output.
EXAMPLE = """example:
  geocode.py velocity.h5
  geocode.py velocity.h5 -b -0.5 -0.25 -91.3 -91.1
  geocode.py velocity.h5 timeseries.h5 -t smallbaselineApp.cfg --outdir ./geo --update

  # geocode file using ISCE-2 lat/lon.rdr file
  geocode.py filt_fine.int --lat-file ../../geom_reference/lat.rdr --lon-file ../../geom_reference/lon.rdr

  # radar-code file in geo coordinates
  geocode.py swbdLat_S02_N01_Lon_W092_W090.wbd -l geometryRadar.h5 -o waterMask.rdr --geo2radar
  geocode.py geo_velocity.h5 --geo2radar
"""

# Degree -> meter conversion table appended to the --lalo-step help text.
# NOTE(review): the listed pairs correspond to ~108 km per degree
# (e.g. 0.000925926 * 108000 = 100); one degree on the equator is ~111.32 km,
# so the meter values should be read as approximate / nominal.
DEG2METER = """
degrees     --> meters on equator
0.000925926 --> 100
0.000833334 --> 90
0.000555556 --> 60
0.000462963 --> 50
0.000277778 --> 30
0.000185185 --> 20
0.000092593 --> 10
"""


def create_parser():
    """Create the command-line parser of geocode.py.

    Returns: parser - argparse.ArgumentParser, with the input/lookup options, the output
                      grid (bbox / lat-lon step), interpolation, output and memory options.
    """
    parser = argparse.ArgumentParser(description='Resample radar coded files into geo coordinates, or reverse',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog=TEMPLATE+'\n'+EXAMPLE)

    parser.add_argument('file', nargs='+', help='File(s) to be geocoded')
    parser.add_argument('-d', '--dset', help='dataset to be geocoded, for example:\n' +
                        'height                        for geometryRadar.h5\n' +
                        'unwrapPhase-20100114_20101017 for ifgramStack.h5')

    parser.add_argument('-l', '--lookup', dest='lookupFile',
                        help='Lookup table file generated by InSAR processors.')
    parser.add_argument('--lat-file', dest='latFile', help='lookup table file for latitude.')
    parser.add_argument('--lon-file', dest='lonFile', help='lookup table file for longitude.')

    # store_false: radar2geo is True by default and --geo2radar switches it off
    parser.add_argument('--geo2radar', '--geo2rdr', dest='radar2geo', action='store_false',
                        help='resample geocoded files into radar coordinates.\n' +
                             'ONLY for lookup table in radar-coord (ISCE, Doris).')

    parser.add_argument('-t', '--template', dest='templateFile',
                        help="Template file with geocoding options.")

    # output grid / geometry
    out = parser.add_argument_group('grid in geo-coordinates')
    out.add_argument('-b', '--bbox', dest='SNWE', type=float, nargs=4, metavar=('S', 'N', 'W', 'E'),
                     help='Bounding box for the area of interest.\n'
                          'using coordinates of the upper left corner of the first pixel\n'
                          '                 and the lower right corner of the last pixel\n'
                          "for radar2geo, it's the output spatial extent\n"
                          "for geo2radar, it's the input  spatial extent")
    out.add_argument('--lalo', '--lalo-step', dest='laloStep', type=float, nargs=2, metavar=('LAT_STEP', 'LON_STEP'),
                     help='output pixel size in degree in latitude / longitude.{}'.format(DEG2METER))

    # interpolation / resampling
    # note: choices are given as lists (not sets) so that the order printed in
    # --help and in argparse error messages is deterministic across runs
    interp = parser.add_argument_group('interpolation')
    interp.add_argument('-i', '--interp', dest='interpMethod', default='nearest', choices=['nearest', 'linear'],
                        help='interpolation/resampling method (default: %(default)s).')
    interp.add_argument('--fill', dest='fillValue', type=float, default=np.nan,
                        help='Fill value for extrapolation (default: %(default)s).')
    interp.add_argument('-n','--nprocs', dest='nprocs', type=int, default=1,
                        help='number of processors to be used for calculation (default: %(default)s).\n'
                             'Note: Do not use more processes than available processor cores.')
    interp.add_argument('--software', dest='software', default='pyresample', choices=['pyresample', 'scipy'],
                        help='software/module used for interpolation (default: %(default)s)\n'
                             'Note: --bbox is not supported for --software scipy')

    parser.add_argument('--update', dest='updateMode', action='store_true',
                        help='skip resampling if output file exists and newer than input file')
    parser.add_argument('-o', '--output', dest='outfile',
                        help="output file name. Default: add prefix 'geo_' ('rdr_' for --geo2radar)")
    parser.add_argument('--outdir', '--output-dir', dest='out_dir', help='output directory.')

    # computing
    parser = arg_group.add_memory_argument(parser)

    return parser


def cmd_line_parse(iargs=None):
    """Parse the command line, merge template options and validate the result.

    Parameters: iargs - list of str, command-line arguments (None for sys.argv[1:])
    Returns:    inps  - Namespace, checked inputs ready for run_geocode()
    """
    namespace = create_parser().parse_args(args=iargs)

    # values found in the template take precedence for the keys they define
    if namespace.templateFile:
        namespace = read_template2inps(namespace.templateFile, namespace)

    return _check_inps(namespace)


def _check_inps(inps):
    """Validate and complete the parsed inputs.

    Checks, in order: input file existence, lookup table availability, consistency
    between the input file coordinates and the resampling direction, validity of
    --lalo-step, and the number of processors.

    Parameters: inps - Namespace, parsed command-line (and template) inputs
    Returns:    inps - Namespace, updated in place (file list expanded,
                       lookupFile / nprocs filled in)
    Raises:     Exception         - if no input file is found
                FileNotFoundError - if no lookup table is found
    Exits:      status 0 when there is nothing to do (input already in the target coordinates);
                status 1 when --lalo-step is used in an unsupported configuration.
    """
    # check 1 - input file(s) existence
    inps.file = ut.get_file_list(inps.file)
    if not inps.file:
        raise Exception('ERROR: no input file found!')
    elif len(inps.file) > 1:
        # a single output file name can not be applied to multiple input files
        inps.outfile = None

    # check 2 - lookup table existence
    if not inps.lookupFile:
        # grab default lookup table
        inps.lookupFile = ut.get_lookup_file(inps.lookupFile)

        # use isce-2 lat/lon.rdr file
        if not inps.lookupFile and inps.latFile:
            inps.lookupFile = inps.latFile

        # final check
        if not inps.lookupFile:
            raise FileNotFoundError('No lookup table found! Can not geocode without it.')

    # check 3 - src file coordinate & radar2geo operation
    # (Y_FIRST in the metadata marks a file in geo-coordinates)
    atr = readfile.read_attribute(inps.file[0])
    if 'Y_FIRST' in atr.keys() and inps.radar2geo:
        print('input file is already geocoded')
        print('to resample geocoded files into radar coordinates, use --geo2radar option')
        print('exit without doing anything.')
        sys.exit(0)
    elif 'Y_FIRST' not in atr.keys() and not inps.radar2geo:
        print('input file is already in radar coordinates, exit without doing anything')
        sys.exit(0)

    # check 4 - laloStep
    # valid only if:
    # 1. radar2geo = True AND
    # 2. lookup table is in radar coordinates
    # invalid usage is an error: exit with a non-zero status so that callers/shells can detect it
    if inps.laloStep:
        if not inps.radar2geo:
            print('ERROR: --lalo-step can NOT be used together with --geo2radar!')
            sys.exit(1)
        atr = readfile.read_attribute(inps.lookupFile)
        if 'Y_FIRST' in atr.keys():
            print('ERROR: --lalo-step can NOT be used with lookup table file in geo-coordinates!')
            sys.exit(1)

    # check 5 - number of processors for multiprocessing
    inps.nprocs = check_num_processor(inps.nprocs)
    return inps


def read_template2inps(template_file, inps):
    """Read the geocoding options of a template file into the Namespace inps.

    Only keys 'mintpy.geocode.<name>' whose <name> is already an attribute of inps
    and whose value is non-empty are considered; of those, SNWE / laloStep (list of
    float), interpMethod (str) and fillValue (float, NaN if the value contains 'nan')
    are applied. A non-empty 'mintpy.compute.maxMemory' sets inps.maxMemory (float).
    Applied template values overwrite the values already stored in inps.

    Parameters: template_file - str, path of the template file
                inps          - Namespace to update; if falsy, it is built by cmd_line_parse()
    Returns:    inps          - Namespace, updated in place
    """
    print('read input option from template file: ' + template_file)
    if not inps:
        inps = cmd_line_parse()
    opts = vars(inps)

    template = ut.check_template_auto_value(readfile.read_template(template_file))

    prefix = 'mintpy.geocode.'
    # drop list brackets and treat commas as separators, e.g. '[-0.5, -0.25]' -> '-0.5  -0.25'
    list_chars = str.maketrans({'[': None, ']': None, ',': ' '})
    matched = [name for name in opts if prefix + name in template]
    for name in matched:
        text = template[prefix + name]
        if not text:
            continue

        if name == 'SNWE' or name == 'laloStep':
            opts[name] = [float(num) for num in text.translate(list_chars).split()]
        elif name == 'interpMethod':
            opts[name] = text
        elif name == 'fillValue':
            opts[name] = np.nan if 'nan' in text.lower() else float(text)

    # computing configurations
    max_memory = template.get('mintpy.compute.maxMemory')
    if max_memory:
        inps.maxMemory = float(max_memory)

    return inps


############################################################################################
def check_num_processor(nprocs):
    """Check number of processors
    Note by Yunjun, 2019-05-02:
    1. conda install pyresample will install pykdtree and openmp, but it seems not working
        geocode.py is getting slower with more processors
    2. macports seems to have minor speedup when more processors
    Thus, default number of processors is set to 1; although the capability of using multiple
    processors is written here.

    Parameters: nprocs - int, requested number of processors; 0 / None means auto:
                         $OMP_NUM_THREADS if it holds an integer, otherwise half of the CPU cores
    Returns:    nprocs - int, number of processors to use, clipped to [1, number of CPU cores]
    """
    # os.cpu_count() returns None when the number of CPUs can not be determined
    num_cpu = os.cpu_count() or 1

    if not nprocs:
        # OMP_NUM_THREADS is defined in environment variable for OpenMP;
        # fall back to half of the cores if it is unset or not a plain integer
        try:
            nprocs = int(os.environ['OMP_NUM_THREADS'])
        except (KeyError, ValueError):
            nprocs = num_cpu // 2

    # at least 1 (half of a single core rounds down to 0; negative requests are invalid)
    # and at most the number of available cores
    nprocs = max(1, min(num_cpu, nprocs))
    print('number of processor to be used: {}'.format(nprocs))
    return nprocs


def auto_output_filename(infile, inps):
    """Return the output file name for one input file.

    An explicit -o/--output name is used only when there is a single input file.
    Otherwise the name is 'geo_' (radar2geo) or 'rdr_' (geo2radar) followed by either
    '<dset>.h5' (if --dset is given) or the base name of the input file, placed in
    --outdir when given (the directory is created if missing).

    Parameters: infile  - str, path of the input file
                inps    - Namespace with file, outfile, radar2geo, dset, out_dir
    Returns:    outfile - str, output file path
    """
    if inps.outfile and len(inps.file) == 1:
        return inps.outfile

    prefix = 'geo_' if inps.radar2geo else 'rdr_'
    if inps.dset:
        name = prefix + '{}.h5'.format(inps.dset)
    else:
        name = prefix + os.path.basename(infile)

    if not inps.out_dir:
        return name

    if not os.path.isdir(inps.out_dir):
        os.makedirs(inps.out_dir)
        print('create directory: {}'.format(inps.out_dir))
    return os.path.join(inps.out_dir, name)


def run_geocode(inps):
    """Geocode (or radar-code with --geo2radar) all input files.

    One resample object is prepared from the lookup table, using the largest input
    file, and reused for every input file. Each dataset is read, resampled and
    written block by block following the resample object's source/destination boxes.

    Parameters: inps    - Namespace, checked inputs from cmd_line_parse()
    Returns:    outfile - str, output file name of the last input file
                          (also when that file was skipped in update mode)
    """
    start_time = time.time()

    # feed the largest file for resample object initiation
    ind_max = np.argmax([os.path.getsize(i) for i in inps.file])

    # prepare geometry for geocoding
    kwargs = dict(interp_method=inps.interpMethod,
                  fill_value=inps.fillValue,
                  nprocs=inps.nprocs,
                  max_memory=inps.maxMemory,
                  software=inps.software,
                  print_msg=True)
    if inps.latFile and inps.lonFile:
        kwargs['lat_file'] = inps.latFile
        kwargs['lon_file'] = inps.lonFile
    res_obj = resample(lut_file=inps.lookupFile,
                       src_file=inps.file[ind_max],
                       SNWE=inps.SNWE,
                       lalo_step=inps.laloStep,
                       **kwargs)
    res_obj.open()
    res_obj.prepare()

    # resample input files one by one
    for infile in inps.file:
        print('-' * 50+'\nresampling file: {}'.format(infile))
        atr = readfile.read_attribute(infile, datasetName=inps.dset)
        outfile = auto_output_filename(infile, inps)

        # update_mode
        if inps.updateMode:
            print('update mode: ON')
            if ut.run_or_skip(outfile, in_file=[infile, inps.lookupFile]) == 'skip':
                continue

        ## prepare output
        # update metadata
        if inps.radar2geo:
            atr = attr.update_attribute4radar2geo(atr, res_obj=res_obj)
        else:
            atr = attr.update_attribute4geo2radar(atr, res_obj=res_obj)

        # instantiate output file
        file_is_hdf5 = os.path.splitext(outfile)[1] in ['.h5', '.he5']
        if file_is_hdf5:
            compression = readfile.get_hdf5_compression(infile)
            writefile.layout_hdf5(outfile, metadata=atr, ref_file=infile, compression=compression)
        else:
            # binary output: all datasets are kept in memory and written at the end
            dsDict = dict()

        ## run
        dsNames = readfile.get_dataset_list(infile, datasetName=inps.dset)
        maxDigit = max(len(i) for i in dsNames)
        for dsName in dsNames:

            # loop for block-by-block IO
            for i in range(res_obj.num_box):
                src_box = res_obj.src_box_list[i]
                dest_box = res_obj.dest_box_list[i]

                # read
                print('-'*50+'\nreading {d:<{w}} in block {b} from {f} ...'.format(
                    d=dsName, w=maxDigit, b=src_box, f=os.path.basename(infile)))

                data = readfile.read(infile,
                                     datasetName=dsName,
                                     box=src_box,
                                     print_msg=False)[0]

                # resample
                data = res_obj.run_resample(src_data=data, box_ind=i)

                # write / save block data
                # boxes are (x0, y0, x1, y1); blocks are row range then column range
                if data.ndim == 3:
                    block = [0, data.shape[0],
                             dest_box[1], dest_box[3],
                             dest_box[0], dest_box[2]]
                else:
                    block = [dest_box[1], dest_box[3],
                             dest_box[0], dest_box[2]]

                if file_is_hdf5:
                    print('write data in block {} to file: {}'.format(block, outfile))
                    writefile.write_hdf5_block(outfile,
                                               data=data,
                                               datasetName=dsName,
                                               block=block,
                                               print_msg=False)
                else:
                    if dsName not in dsDict:
                        # allocate the full-size array once the data type is known:
                        # float64 for real data, complex for complex data (e.g. ISCE *.int);
                        # a float64 buffer would silently drop the imaginary part
                        buf_dtype = np.result_type(np.float64, data.dtype)
                        dsDict[dsName] = np.zeros((res_obj.length, res_obj.width), dtype=buf_dtype)
                    dsDict[dsName][block[0]:block[1],
                                   block[2]:block[3]] = data

            # for binary file: ensure same data type
            if not file_is_hdf5:
                dsDict[dsName] = np.asarray(dsDict[dsName], dtype=data.dtype)

        # write binary file
        if not file_is_hdf5:
            atr['BANDS'] = len(dsDict.keys())
            writefile.write(dsDict, out_file=outfile, metadata=atr, ref_file=infile)

            # create ISCE XML and GDAL VRT file if using ISCE lookup table file
            if inps.latFile and inps.lonFile:
                writefile.write_isce_xml(atr, fname=outfile)

    m, s = divmod(time.time()-start_time, 60)
    print('time used: {:02.0f} mins {:02.1f} secs.\n'.format(m, s))
    return outfile


######################################################################################
def main(iargs=None):
    """Command-line entry point: parse/check the arguments, then geocode all input files.

    Parameters: iargs - list of str, command-line arguments (None for sys.argv[1:])
    """
    run_geocode(cmd_line_parse(iargs))


######################################################################################
if __name__ == '__main__':
    # forward the command-line arguments (without the script name) to main()
    main(iargs=sys.argv[1:])
